from src.bulletEnv import BulletEnv
import os
import time
import utils
from stable_baselines3 import SAC
import pickle
import torch
import numpy as np
import random

class Simulator(object):
    """Wrapper around a ``BulletEnv`` that executes saved SAC policies.

    Besides driving the environment, the simulator keeps a "run bag": a plain
    dict describing the run (configs, robot, environment name) together with
    the actions that have been stashed so far, which can be pickled to disk
    with :meth:`save_actions`.
    """

    # Fallback filename prefix used by ``save_actions`` when no prefix was
    # ever passed to ``setup_bag``. Without it, ``self.bag_prefix`` would be
    # undefined after ``__init__`` and saving would raise AttributeError.
    bag_prefix = "run"

    def __init__(self, init_config, goal_config, problem_config, term_sampler,mp = False,seed=1337):
        """Build the underlying ``BulletEnv`` and reset it once.

        Args:
            init_config: Passed to the environment as ``init_set`` and
                recorded in the run bag.
            goal_config: Passed to the environment as ``term_set`` and
                recorded in the run bag.
            problem_config: Mapping with the problem settings. Keys read
                here: ``debug``, ``env_path``, ``env_name``,
                ``region_policy['max_ep_len']``, ``robot``, and either
                ``mp_robot`` (+ optional ``mp_gui``) or ``sim_robot`` +
                ``simulator_gui`` depending on ``mp``.
            term_sampler: Forwarded unchanged to ``BulletEnv``.
            mp: When true, use the ``mp_robot`` configuration and the
                optional ``mp_gui`` flag instead of the ``sim_robot`` /
                ``simulator_gui`` pair.
            seed: Seed forwarded to ``BulletEnv``.
        """
        print("Initializing simulator")
        self.actions = []
        self.problem_config = problem_config
        self.runbag = self.setup_bag(init_config, goal_config)
        if mp:
            robot_key = "mp_robot"
            # A missing "mp_gui" key means "no GUI".
            gui = "mp_gui" in problem_config and bool(problem_config["mp_gui"])
        else:
            robot_key = "sim_robot"
            gui = problem_config["simulator_gui"]
        self.env = BulletEnv(seed=seed,
                            gui=gui,
                            init_set=init_config,
                            term_set=goal_config,
                            term_sampler = term_sampler,
                            demo_mode=self.problem_config['debug'],
                            robot_config=self.problem_config[robot_key],
                            env_path=os.path.join(self.problem_config['env_path'],
                                                  self.problem_config['env_name']+".stl"),
                            forked=False,
                            envid=robot_key,
                            max_ep_len=self.problem_config["region_policy"]['max_ep_len'])
        self.env.env_mode = 'eval'
        self.obs = self.reset()

    def set_init_sample_and_eval_func(self,init_sampler,eval_func,term_sampler):
        """Clear the pending action log and hand new samplers/eval func to the env."""
        print("Resetting sampler and eval func")
        self.reset_action_log()
        self.env.set_sampler_and_eval_func(init_sampler,eval_func,term_sampler)

    def sync_simulator(self, init_config):
        """Set the full robot state to ``init_config`` and return the env's result."""
        print("Syncing simulator")
        return self.env.setFullRobotState(init_config,sim=False)

    def reset(self):
        """Reset the environment and return (and cache) the new observation.

        The environment is switched to ``'train'`` mode only for the duration
        of the reset call and is put back into ``'eval'`` mode afterwards,
        even if the reset raises.
        """
        self.env.env_mode = 'train'
        try:
            self.obs = self.env.reset()
        finally:
            self.env.env_mode = 'eval'
        return self.obs

    def get_collision_fn(self):
        """Return a callable ``pose -> self.env.collision_at_pose(pose)``."""
        def collision_fn(pose):
            # Look up ``self.env`` at call time rather than binding it now.
            return self.env.collision_at_pose(pose)
        return collision_fn

    def plot(self,pose):
        """Draw a blue debug line from z=0.0 to z=0.5 at ``(pose[0], pose[1])``."""
        self.env._p.addUserDebugLine([pose[0],pose[1],0.0],[pose[0],pose[1],0.5],[0,0,1])

    def execute_policy(self, policy_path, option_guide, option_switch_point):
        """Load a SAC policy and step the environment with it.

        Stepping continues until the environment reports ``done`` or
        ``self.env.max_ep_len`` actions have been taken. Every action is
        appended to ``self.actions``.

        Args:
            policy_path: Path handed to ``SAC.load``.
            option_guide: Forwarded to ``self.env.update_targets``.
            option_switch_point: Forwarded to ``self.env.update_targets``.

        Returns:
            The number of actions executed.
        """
        print("Executing policy...")
        model = SAC.load(policy_path)

        self.env.info = utils.misc.create_env_info_dict()
        self.env.update_targets(option_guide, option_switch_point)

        action_count = 0
        while action_count < self.env.max_ep_len:
            action, _ = model.predict(self.obs)
            self.actions.append(action)
            self.obs, _reward, done, _info = self.env.step(action)
            action_count += 1
            if done:
                break
        print("Finished policy exec")
        # Clear the flag so a subsequent execution starts from a non-done env.
        self.env.done = False
        return action_count

    def set_seed(self, seed):
        """Seed torch's CPU and CUDA generators (numpy/random are not touched)."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def get_ll_config(self):
        """Return the robot pose reported by ``self.env.robot.getPose()``."""
        print("Fetching LL config")
        return self.env.robot.getPose()

    def save_actions(self,logtype,logdir):
        """Pickle the run bag into ``logdir``.

        The file is named ``<bag_prefix>_<logtype>_<timestamp>.pkl``. Only
        actions already moved into the bag by :meth:`stash_actions` are
        written; ``self.actions`` is not included.

        Args:
            logtype: Label inserted into the filename.
            logdir: Target directory; created if it does not exist.
        """
        if logdir:
            os.makedirs(logdir, exist_ok=True)
        # ``%b`` is the portable spelling of the original ``%h`` (abbreviated
        # month name); the resulting filename is the same on POSIX platforms.
        stamp = time.strftime('%d%y%b_%H%M%S')
        fname = os.path.join(logdir,
                             self.bag_prefix+"_"+logtype+"_"+stamp+'.pkl')

        with open(fname, 'wb') as pickleout:
            pickle.dump(self.runbag, pickleout)

    def reset_action_log(self):
        """Discard all actions that have not been stashed yet."""
        self.actions = []

    def stash_actions(self):
        """Move the pending actions into the run bag and clear the pending log."""
        self.runbag['run_actions'].extend(self.actions)
        self.reset_action_log()

    def setup_bag(self,init_config, goal_config,prefix=None):
        """Create a fresh run bag and optionally set the filename prefix.

        Args:
            init_config: Stored under ``"init_config"``.
            goal_config: Stored under ``"goal_config"``.
            prefix: When not ``None``, becomes ``self.bag_prefix``; otherwise
                the current prefix (or the class default) is left in place.

        Returns:
            A dict with the configs, the robot name and model path and the
            environment name from ``self.problem_config``, and an empty
            ``"run_actions"`` list.
        """
        if prefix is not None:
            self.bag_prefix = prefix
        return {"init_config":init_config,
                "goal_config":goal_config,
                "robot_name":self.problem_config['robot']['name'],
                "robot_file":self.problem_config['robot']['model_path'],
                "env_name":self.problem_config['env_name'],
                "run_actions":[]}
